from datasets import load_dataset
import json
import random
# import nltk
import re

# import tiktoken 
# nltk.download('punkt')

random.seed(42)

def extract_think_sections(text: str):
    """Split *text* into its first ``<think>...</think>`` body and the text after it.

    Only the first (non-greedy) ``<think>...</think>`` pair is considered; the
    match may span multiple lines.

    Args:
        text: Raw model output expected to contain a ``<think>`` block.

    Returns:
        A ``(think_content, post_think_content)`` tuple. ``think_content`` is
        the stripped text between the tags; ``post_think_content`` is the
        stripped text following the closing ``</think>`` tag.

    Raises:
        ValueError: If *text* contains no complete ``<think>...</think>`` pair.
    """
    think_match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    if think_match is None:
        # Fail loudly with a message, mirroring extract_think_and_solution_V2.
        raise ValueError("Missing required <think>...</think> block.")

    think_content = think_match.group(1).strip()
    post_think_content = text[think_match.end():].strip()
    return think_content, post_think_content

def extract_think_and_solution_V2(text: str):
    """Pull the thought and solution bodies out of an OpenThoughts-style response.

    The response must contain a ``<|begin_of_thought|>...<|end_of_thought|>``
    block followed (after optional whitespace) by a
    ``<|begin_of_solution|>...<|end_of_solution|>`` block.

    Args:
        text: Raw response string.

    Returns:
        A ``(think_content, post_think_content)`` tuple holding the stripped
        thought body and the stripped solution body.

    Raises:
        ValueError: If the two blocks are not found in that order.
    """
    thought_part = r"<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>\s*"
    solution_part = r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>"
    found = re.search(thought_part + solution_part, text, re.DOTALL)

    if found is None:
        raise ValueError("Missing required <|begin_of_thought|> or <|begin_of_solution|> blocks.")

    thought, solution = (segment.strip() for segment in found.groups())
    return thought, solution


from transformers import AutoTokenizer

# Tokenizer shared by `split_by_tokens`. It is instantiated at module level, so it
# is created as soon as this file is imported or run.
# NOTE(review): `from_pretrained` presumably needs the "deepseek-ai/DeepSeek-R1"
# tokenizer files to be cached locally or downloadable — confirm for offline runs.
ds_tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")

def split_by_tokens(text: str, chunk_size: int = 100):
    """Cut *text* into consecutive pieces of at most *chunk_size* tokens.

    The text is encoded with the module-level ``ds_tokenizer`` (no special
    tokens added), the resulting id list is sliced into back-to-back windows
    of ``chunk_size`` ids, and every window is decoded and stripped.

    Args:
        text: The string to split.
        chunk_size: Maximum number of token ids per piece.

    Returns:
        A list of decoded, whitespace-stripped strings, one per window, in
        the original order.
    """
    token_ids = ds_tokenizer.encode(text, add_special_tokens=False)
    return [
        ds_tokenizer.decode(
            token_ids[start:start + chunk_size], skip_special_tokens=True
        ).strip()
        for start in range(0, len(token_ids), chunk_size)
    ]

# def split_by_tokens(text: str, chunk_size: int = 100, model_name: str = "gpt2"):
#     enc = tiktoken.get_encoding(model_name)
#     tokens = enc.encode(text)

#     chunks = []
#     for i in range(0, len(tokens), chunk_size):
#         chunk_tokens = tokens[i:i+chunk_size]
#         chunk_text = enc.decode(chunk_tokens)
#         chunks.append(chunk_text.strip())
#     return chunks

if __name__ == "__main__":
    # Script entry point: for every block size and every configured dataset,
    # extract (reasoning, answer) records, chunk the reasoning by tokens, and
    # write a test split and a training split as JSON.
    import os  # local import: only this entry point builds output paths

    # Dataset name -> local directory handed to `load_dataset`.
    datasets_config = {
        "OpenR1-Math-220k": "/workspace/0407_nips/data_preprocess/OpenR1-Math-220k/data",
        "reasoning-v1-20m": "/workspace/0407_nips/data_preprocess/reasoning-v1-20m/data",
        "OpenThoughts-114k-math": "/workspace/0407_nips/data_preprocess/OpenThoughts-114k-math/data",
        "OpenThoughts-114k-Code_decontaminated": "/workspace/0407_nips/data_preprocess/OpenThoughts-114k-Code_decontaminated/data",
        "Medical-R1-Distill-Data": "/workspace/0407_nips/data_preprocess/Medical-R1-Distill-Data"
    }

    max_source_rows = 40000  # rows taken from each shuffled dataset before filtering
    test_size = 1000         # records held out for the test split
    train_size = 20000       # upper bound on records in the training split

    block_sizes = [256, 512, 1024]
    for block_size in block_sizes:
        # Each block size gets its own output directory; writing to a shared
        # location would let later block sizes overwrite earlier results.
        root_dir = f"/workspace/0407_nips/data_preprocess/0507_final_deepseek_token-{block_size}"
        os.makedirs(root_dir, exist_ok=True)

        for name, path in datasets_config.items():
            print(f"{name}")
            dataset = load_dataset(path)["train"]
            shuffled_dataset = dataset.shuffle(seed=42)

            # Take the first `max_source_rows` shuffled rows, or everything
            # when the dataset is smaller than that.
            if len(shuffled_dataset) >= max_source_rows:
                sampled_dataset = shuffled_dataset.select(range(max_source_rows))
            else:
                print(len(dataset))
                sampled_dataset = shuffled_dataset

            filtered_data = []
            if name == "OpenR1-Math-220k":
                for example in sampled_dataset:
                    # Keep only the first generation that is complete,
                    # verified correct, and parses successfully.
                    for i, text in enumerate(example["generations"]):
                        if not (example["is_reasoning_complete"][i] and example["correctness_math_verify"][i]):
                            continue
                        try:
                            think_content, post_think_content = extract_think_sections(text)
                            filtered_data.append({
                                "reasoning": think_content,
                                "reasoning_sentences": split_by_tokens(think_content, chunk_size=block_size),
                                "answer": post_think_content,
                                "question": example["problem"],
                                "solution": example["solution"],
                                "extracted_answer": example["answer"],
                                "uuid": example["uuid"],
                            })
                            break
                        except ValueError:
                            continue
            elif name == "reasoning-v1-20m":
                for example in sampled_dataset:
                    if "</think>" in example["response"]:
                        try:
                            think_content, post_think_content = extract_think_sections(example["response"])
                            filtered_data.append({
                                "reasoning": think_content,
                                "reasoning_sentences": split_by_tokens(think_content, chunk_size=block_size),
                                "answer": post_think_content,
                                "question": example["prompt"],
                                "solution": "",
                                "extracted_answer": "",
                                "uuid": "",
                            })
                        except ValueError:
                            print("wrong")
            elif name == "OpenThoughts-114k-math":
                for example in sampled_dataset:
                    try:
                        text = example["conversations"][1]["value"]
                        think_content, post_think_content = extract_think_and_solution_V2(text)
                        filtered_data.append({
                            "reasoning": think_content,
                            "reasoning_sentences": split_by_tokens(think_content, chunk_size=block_size),
                            "answer": post_think_content,
                            "question": example["problem"],
                            "solution": example["solution"],
                            "extracted_answer": "",
                            "uuid": example["source"],
                        })
                    except ValueError:
                        print("wrong")
            elif name == "OpenThoughts-114k-Code_decontaminated":
                # Reasoning and solution are read from dedicated columns; no tag parsing.
                for example in sampled_dataset:
                    try:
                        filtered_data.append({
                            "reasoning": example["deepseek_reasoning"],
                            "reasoning_sentences": split_by_tokens(example["deepseek_reasoning"], chunk_size=block_size),
                            "answer": example["deepseek_solution"],
                            "question": example["problem"],
                            "solution": "",
                            "extracted_answer": "",
                            "uuid": example["source"],
                        })
                    except ValueError:
                        print("wrong")
            elif name == "Medical-R1-Distill-Data":
                # Reasoning and response are read from dedicated columns; no tag parsing.
                for example in sampled_dataset:
                    try:
                        filtered_data.append({
                            "reasoning": example["reasoning (reasoning_content)"],
                            "reasoning_sentences": split_by_tokens(example["reasoning (reasoning_content)"], chunk_size=block_size),
                            "answer": example["response (content)"],
                            "question": example["question"],
                            "solution": "",
                            "extracted_answer": "",
                            "uuid": "",
                        })
                    except ValueError:
                        print("wrong")

            print(f"筛选后数据量: {len(filtered_data)}")

            # Clamp the test sample size: random.sample raises ValueError when
            # asked for more items than the population holds, which would
            # abort the remaining datasets and block sizes.
            if len(filtered_data) < test_size:
                print(f"only {len(filtered_data)} records available for the test split")
            sampled_filtered_data = random.sample(filtered_data, min(test_size, len(filtered_data)))
            test_path = os.path.join(root_dir, f"output_test_{name}_1000.json")
            with open(test_path, "w", encoding="utf-8") as f:
                json.dump(sampled_filtered_data, f, ensure_ascii=False, indent=2)

            # Equality-based filter: besides the sampled records themselves,
            # this also drops any record whose content equals a test record.
            remaining_filtered_data = [d for d in filtered_data if d not in sampled_filtered_data]
            if len(remaining_filtered_data) >= train_size:
                sampled_training_data = random.sample(remaining_filtered_data, train_size)
            else:
                sampled_training_data = remaining_filtered_data
                print(f"only {len(remaining_filtered_data)}")

            # NOTE(review): the file name says "10000" although up to 20000
            # records are written; the name is kept unchanged in case
            # downstream scripts rely on it.
            train_path = os.path.join(root_dir, f"output_train_{name}_10000.json")
            with open(train_path, "w", encoding="utf-8") as f:
                json.dump(sampled_training_data, f, ensure_ascii=False, indent=2)
